/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

#include "opencl/engine.hpp"

#include <bitset>
#include <iostream>
#include <stdexcept>

#include "opencl/util.hpp"

#define DECLSPEC

#include "generic/const_def.h"
#include "generic/basepoint_mul_fold_table_2526_def.h"

#include "generic/random.h"
#include "generic/bignum.h"
#include "generic/field.h"
#include "generic/point.h"

#include "kernel.cl.h"

#define MIDSTATE_ITERATIONS 8192

namespace opencl {

namespace {

// Default OpenCL work-group size when GeneratorParams.block_size is 0
// (capped by the device's CL_DEVICE_MAX_WORK_GROUP_SIZE in fill_params()).
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
// Default batch size for the three-pass batch-inversion kernels.
constexpr size_t DEFAULT_INV_BATCH_SIZE = 256;
// Default number of blocks launched per compute unit when
// GeneratorParams.n_blocks is 0 (see fill_params()).
constexpr size_t BLOCKS_PER_MP = 256;

// Host-side copies of bignum constants. The BN_* macros come from
// generic/const_def.h and are shared with the OpenCL kernel sources.
namespace host {

// Mask applied to freshly drawn random addend scalars in init_addends().
inline const Bn256 addend_mask = BN_ADDEND_MASK;

// l_3..l_7: presumably 3*L .. 7*L for the Ed25519 group order L —
// TODO(review): confirm against generic/const_def.h. Used by
// OpenclGeneratorCtx::skey() to pick the multiple whose low 3 bits
// round the reconstructed scalar to a multiple of 8.
inline const Bn256 l_3 = BN_L_3;
inline const Bn256 l_4 = BN_L_4;
inline const Bn256 l_5 = BN_L_5;
inline const Bn256 l_6 = BN_L_6;
inline const Bn256 l_7 = BN_L_7;

} // namespace host

// Debug helper: prints the first 16 elements of `key` as a comma-separated
// list of decimal values, terminated by a newline.
auto print_key = [](auto& key) {
    int idx = 0;
    while(idx < 16) {
        printf("%d, ", key[idx]);
        ++idx;
    }
    printf("\n");
};

static const AffineNielsPoint basepoint_mul_fold_table[256] = ED25519_BASEPOINT_MUL_FOLD_TABLE_2526_DEF;

// Host-side context for one produced batch: holds the computed addresses
// plus the midstate/addend scalars needed to reconstruct, on demand, the
// secret key behind any address in the batch.
class OpenclGeneratorCtx final : public ::GeneratorCtx {
public:
    // Number of addresses in this batch (2 * generator batch size: each
    // midstate yields one "+addend" and one "-addend" key, see skey()).
    size_t batch_size() override { return batch_size_; }
    // Pointer to the 16-byte address at index i (addresses are stored
    // contiguously, 16 bytes each).
    const uint8_t* address(size_t i) override { return addresses_.get() + 16 * i; }
    // Reconstructs the secret scalar for address i as
    //   sk = midstates_[i/2] + chosen_l (+/-) addends_[seq_]
    // where chosen_l is whichever of host::l_3..l_7 has low 3 bits equal to
    // `to8`, so that (midstate + chosen_l) is a multiple of 8 in its low
    // bits (Ed25519 clamping — TODO(review): confirm intent vs the kernel).
    const uint8_t* skey(size_t i) override {
        auto& midstate = midstates_[i / 2];
        auto& addend = addends_[seq_];
        // Amount needed to round the midstate's low 3 bits up to 0 mod 8.
        uint8_t to8 = (8 - (midstate[0] % 8)) & 0x07;

        Bn256 chosen_l;

        // Select the L-multiple whose low 3 bits match to8.
        if(to8 == host::l_3[0] % 8) {
            bn_copy(chosen_l, host::l_3);
        } else if(to8 == host::l_4[0] % 8) {
            bn_copy(chosen_l, host::l_4);
        } else if(to8 == host::l_5[0] % 8) {
            bn_copy(chosen_l, host::l_5);
        } else if(to8 == host::l_6[0] % 8) {
            bn_copy(chosen_l, host::l_6);
        } else if(to8 == host::l_7[0] % 8) {
            bn_copy(chosen_l, host::l_7);
        } else {
            throw std::runtime_error{"Unable to choose L * n"};
        }

        Bn256 sk;
        bn_add(sk, midstate, chosen_l);
        // Even indices correspond to the "+addend" key, odd to "-addend".
        if(i % 2 == 0)
            bn_add(sk, sk, addend);
        else
            bn_sub(sk, sk, addend);

        // TODO : correct conversion
        // NOTE(review): raw limb memcpy — the byte order of the result
        // depends on Bn256's limb layout; verify it matches the expected
        // 32-byte little-endian secret-key encoding.
        memcpy(skey_, sk, 32);
        return skey_; // points into shared scratch; overwritten by next call
    }

    // Shared with OpenclGenerator, which reuses them across batches.
    std::shared_ptr<const Bn256[]> midstates_;  // one scalar per work item
    std::shared_ptr<const Bn256[]> addends_;    // MIDSTATE_ITERATIONS scalars
    std::shared_ptr<uint8_t[]> addresses_;      // 16 bytes per address
    size_t batch_size_;  // number of addresses (2 * midstate count)
    uint64_t seq_;       // addend index this batch was produced with
    uint8_t skey_[32];   // scratch buffer returned by skey()
};

// Generates batches of Ed25519 addresses on an OpenCL device.
//
// Key schedule: each random "midstate" scalar m (refreshed whenever seq_
// wraps to 0) is combined with a precomputed "addend" scalar a[seq_] to
// form two keys per work item, m + a[seq_] and m - a[seq_]. The matching
// curve points are derived on the GPU from precomputed midstate and addend
// points, so only one full scalar multiplication per midstate is needed
// for MIDSTATE_ITERATIONS produced batches.
class OpenclGenerator final : public Generator {
public:
    OpenclGenerator(const GeneratorParams& params)
    // The initializer list is kept in member declaration order: members are
    // always initialized in declaration order regardless of list order, so
    // a mismatched list (previously block_size_/n_blocks_ swapped and
    // addresses_ listed late, flagged by -Wreorder) misleads readers about
    // what actually runs first.
    : n_blocks_{params.n_blocks}
    , block_size_{params.block_size}
    , batch_size_{block_size_ * n_blocks_}
    , inv_batch_size_{params.inv_batch_size}
    , device_{get_gpu(params.device)}
    , ctx_{{device_}}
    , queue_{ctx_, device_}
    , program_{make_program(ctx_, device_, kernel_cl, kernel_cl_len, "-D BLOCK_SIZE=" + std::to_string(block_size_))}
    , precompute_addends_kernel_{program_, "precompute_addends_kernel"}
    , compute_midstate_kernel_0_{program_, "compute_midstate_kernel_0"}
    , compute_midstate_kernel_1_{program_, "compute_midstate_kernel_1"}
    , batch_invert_kernel_0_{program_, "batch_invert_kernel_0"}
    , batch_invert_kernel_1_{program_, "batch_invert_kernel_1"}
    , batch_invert_kernel_2_{program_, "batch_invert_kernel_2"}
    , compute_uv_kernel_0_{program_, "compute_uv_kernel_0"}
    , compute_uv_kernel_1_{program_, "compute_uv_kernel_1"}
    , compute_address_kernel_{program_, "compute_address_kernel"}
    , midstate_ypx_{make_device_buffer<uint32_t>(ctx_, FE_SIZE * batch_size_)}
    , midstate_ymx_{make_device_buffer<uint32_t>(ctx_, FE_SIZE * batch_size_)}
    , midstate_xy_{make_device_buffer<uint32_t>(ctx_, FE_SIZE * batch_size_)}
    , addresses_{make_device_buffer<uint8_t>(ctx_, 16 * 2 * batch_size_)}
    , seeds_{ctx_, batch_size_}
    , midstate_scalars_{ctx_, batch_size_}
    , basepoint_mul_fold_table_{ctx_, 256}
    , addend_sum_{ctx_, 1}
    , midstate_iterations_buf_{ctx_, 1}
    , us_{make_device_buffer<uint32_t>(ctx_, 2 * FE_SIZE * batch_size_)}
    , us_inv_{make_device_buffer<uint32_t>(ctx_, 2 * FE_SIZE * batch_size_)}
    , vs_{make_device_buffer<uint32_t>(ctx_, 2 * FE_SIZE * batch_size_)}
    , ts_{make_device_buffer<uint32_t>(ctx_, 2 * FE_SIZE * batch_size_)}
    , seq_buf_{make_device_buffer<uint64_t>(ctx_, 1)}
    , inv_batch_size_buf_{make_device_buffer<size_t>(ctx_, 1)}
    {
        rng_init(&rng_, params.seed, params.seq);
        copy_to_device(inv_batch_size_buf_, &inv_batch_size_, 1, queue_);
        *midstate_iterations_buf_.host() = MIDSTATE_ITERATIONS;
        midstate_iterations_buf_.copy_to_device(queue_);
    }

private:
    // Generator
    // Produces one batch of addresses into raw_ctx (must be an
    // OpenclGeneratorCtx from make_generator_context()). Midstates are
    // regenerated every MIDSTATE_ITERATIONS calls, when seq_ wraps to 0.
    void produce(::GeneratorCtx& raw_ctx) override {
        if(seq_ == 0) init_midstate();
        compute_addresses(dynamic_cast<OpenclGeneratorCtx&>(raw_ctx));
    }

private:
    // Draws MIDSTATE_ITERATIONS random addend scalars, uploads them together
    // with the basepoint table, and precomputes the addend points on the
    // device. Also accumulates the sum of all addends (consumed by
    // compute_midstate_kernel_0_). Runs once, lazily, before the first
    // midstate initialization.
    void init_addends()
    {
        auto addend_scalars = std::make_unique<Bn256[]>(MIDSTATE_ITERATIONS);
        DualBuffer<Bn256> addend_scalars_buffer{ctx_, MIDSTATE_ITERATIONS};

        memcpy(basepoint_mul_fold_table_.host(), basepoint_mul_fold_table, sizeof(basepoint_mul_fold_table));
        basepoint_mul_fold_table_.copy_to_device(queue_);

        Bn256 a_sum;
        bn_zero(a_sum);

        for(auto i = 0; i < MIDSTATE_ITERATIONS; ++i) {
            Bn256 a;
            rng_fill_bytes(&rng_, a, 32);
            bn_and(a, a, host::addend_mask); // constrain addend range
            bn_add(a_sum, a_sum, a);

            bn_copy(addend_scalars[i], a);
            bn_copy(addend_scalars_buffer.host()[i], a);
        }

        addends_ypx_ = make_device_buffer<uint32_t>(ctx_, FE_SIZE * MIDSTATE_ITERATIONS);
        addends_ymx_ = make_device_buffer<uint32_t>(ctx_, FE_SIZE * MIDSTATE_ITERATIONS);
        addends_xy_ = make_device_buffer<uint32_t>(ctx_, FE_SIZE * MIDSTATE_ITERATIONS);

        // Keep the scalars on the host so OpenclGeneratorCtx::skey() can
        // reconstruct secret keys later.
        addend_scalars_shared_ = std::move(addend_scalars);
        bn_copy(*addend_sum_.host(), a_sum);

        addend_scalars_buffer.copy_to_device(queue_);
        addend_sum_.copy_to_device(queue_);

        bind_kernel(
            precompute_addends_kernel_,
            addends_ypx_,
            addends_ymx_,
            addends_xy_,
            addend_scalars_buffer.device(),
            basepoint_mul_fold_table_.device(),
            midstate_iterations_buf_.device()
        );

        // NOTE(review): assumes block_size_ divides MIDSTATE_ITERATIONS;
        // a larger block size would truncate the launch — confirm upstream
        // validation guarantees this.
        run_kernel(precompute_addends_kernel_, MIDSTATE_ITERATIONS / block_size_, block_size_, queue_);

        addends_initialized_ = true;
    }

    // Draws fresh random midstate scalars and computes the corresponding
    // midstate points (y+x, y-x, x*y coordinates) on the device, including
    // a three-pass batch inversion of the z coordinates.
    void init_midstate()
    {
        if(!addends_initialized_) init_addends();

        // Invalidate the host-side snapshot; compute_addresses() re-copies
        // the new scalars after this batch completes.
        midstate_scalars_shared_ = nullptr;

        rng_fill_bytes(&rng_, reinterpret_cast<uint8_t*>(seeds_.host()), seeds_.size() * sizeof(uint64_t));

        seeds_.copy_to_device(queue_);
        bind_kernel(
            compute_midstate_kernel_0_,
            midstate_ypx_,
            midstate_ymx_,
            us_,
            midstate_scalars_.device(),
            seeds_.device(),
            addend_sum_.device(),
            basepoint_mul_fold_table_.device()
        );

        run_kernel(compute_midstate_kernel_0_, n_blocks_, block_size_, queue_);

        // Batch inversion of us_ into us_inv_ (three passes).
        bind_kernel(batch_invert_kernel_0_, us_inv_, us_, inv_batch_size_buf_);
        run_kernel(batch_invert_kernel_0_, n_blocks_ / inv_batch_size_, block_size_, queue_);

        bind_kernel(batch_invert_kernel_1_, us_inv_, inv_batch_size_buf_);
        run_kernel(batch_invert_kernel_1_, n_blocks_ / inv_batch_size_, block_size_, queue_);

        bind_kernel(batch_invert_kernel_2_, us_inv_, us_, inv_batch_size_buf_);
        run_kernel(batch_invert_kernel_2_, n_blocks_ / inv_batch_size_, block_size_, queue_);

        bind_kernel(
            compute_midstate_kernel_1_,
            midstate_ypx_,
            midstate_ymx_,
            midstate_xy_,
            us_inv_
        );
        run_kernel(compute_midstate_kernel_1_, n_blocks_, block_size_, queue_);
        midstate_scalars_.copy_to_host(queue_);
    }

    // Derives the two per-midstate points for addend index seq_, batch-
    // inverts, hashes them to addresses on the device, and copies the
    // results (plus the shared scalar snapshots) into ctx.
    void compute_addresses(OpenclGeneratorCtx& ctx)
    {
        copy_to_device(seq_buf_, &seq_, 1, queue_);

        bind_kernel(
            compute_uv_kernel_0_,
            ts_, // zs
            us_inv_, // efs
            midstate_xy_,
            addends_xy_,
            seq_buf_
        );
        run_kernel(compute_uv_kernel_0_, n_blocks_, block_size_, queue_);

        bind_kernel(
            compute_uv_kernel_1_,
            us_,
            vs_,
            ts_, // zs
            us_inv_, // efs
            midstate_ypx_,
            midstate_ymx_,
            addends_ypx_,
            addends_ymx_,
            seq_buf_
        );
        // From here on the launches cover both the +addend and -addend
        // halves, hence n_blocks_2_.
        run_kernel(compute_uv_kernel_1_, n_blocks_2_, block_size_, queue_);

        bind_kernel(batch_invert_kernel_0_, us_inv_, us_, inv_batch_size_buf_);
        run_kernel(batch_invert_kernel_0_, n_blocks_2_ / inv_batch_size_, block_size_, queue_);

        bind_kernel(batch_invert_kernel_1_, us_inv_, inv_batch_size_buf_);
        run_kernel(batch_invert_kernel_1_, n_blocks_2_ / inv_batch_size_, block_size_, queue_);

        bind_kernel(batch_invert_kernel_2_, us_inv_, us_, inv_batch_size_buf_);
        run_kernel(batch_invert_kernel_2_, n_blocks_2_ / inv_batch_size_, block_size_, queue_);

        bind_kernel(
            compute_address_kernel_,
            addresses_,
            us_inv_,
            vs_
        );
        run_kernel(compute_address_kernel_, n_blocks_2_, block_size_, queue_);

        // Lazily allocate the host address buffer (16 bytes per address,
        // two addresses per work item).
        if(!ctx.addresses_) {
            ctx.addresses_ = std::make_unique<uint8_t[]>(batch_size_ * 16 * 2);
        }

        ctx.addends_ = addend_scalars_shared_;
        ctx.batch_size_ = batch_size_ * 2;
        ctx.seq_ = seq_;
        memcpy_to_host(ctx.addresses_.get(), addresses_, 16 * 2 * batch_size_, queue_);

        if(++seq_ >= MIDSTATE_ITERATIONS) seq_ = 0;

        OPENCL_CHECK(queue_.flush());
        OPENCL_CHECK(queue_.finish());

        // Snapshot the midstate scalars once per midstate generation and
        // share the snapshot with every ctx produced from it.
        if(!midstate_scalars_shared_) {
            auto midstate_scalars_tmp = std::make_unique<Bn256[]>(batch_size_);
            memcpy(midstate_scalars_tmp.get(), midstate_scalars_.host(), batch_size_ * sizeof(Bn256));
            midstate_scalars_shared_ = std::move(midstate_scalars_tmp);
        }
        ctx.midstates_ = midstate_scalars_shared_;
    }

private:
    Rng rng_;

    // Launch geometry (batch_size_ = work items per midstate generation).
    const size_t n_blocks_;
    const size_t block_size_;
    const size_t batch_size_;
    const size_t n_blocks_2_{n_blocks_ * 2}; // +addend and -addend halves
    const size_t inv_batch_size_;

    cl::Device device_;
    cl::Context ctx_;
    cl::CommandQueue queue_;
    cl::Program program_;

    cl::Kernel precompute_addends_kernel_;
    cl::Kernel compute_midstate_kernel_0_;
    cl::Kernel compute_midstate_kernel_1_;
    cl::Kernel batch_invert_kernel_0_;
    cl::Kernel batch_invert_kernel_1_;
    cl::Kernel batch_invert_kernel_2_;
    cl::Kernel compute_uv_kernel_0_;
    cl::Kernel compute_uv_kernel_1_;
    cl::Kernel compute_address_kernel_;

    // Precomputed addend points (allocated lazily in init_addends()).
    cl::Buffer addends_ypx_;
    cl::Buffer addends_ymx_;
    cl::Buffer addends_xy_;
    // Midstate points.
    cl::Buffer midstate_ypx_;
    cl::Buffer midstate_ymx_;
    cl::Buffer midstate_xy_;
    cl::Buffer addresses_;

    DualBuffer<uint64_t> seeds_;
    DualBuffer<Bn256> midstate_scalars_;
    DualBuffer<AffineNielsPoint> basepoint_mul_fold_table_;
    DualBuffer<Bn256> addend_sum_;
    DualBuffer<size_t> midstate_iterations_buf_;

    // Host-side snapshots handed out to contexts via shared_ptr.
    std::shared_ptr<const Bn256[]> midstate_scalars_shared_;
    std::shared_ptr<const Bn256[]> addend_scalars_shared_;

    // Scratch buffers for the uv/inversion pipeline.
    cl::Buffer us_;
    cl::Buffer us_inv_;
    cl::Buffer vs_;
    cl::Buffer ts_;
    cl::Buffer seq_buf_;
    cl::Buffer inv_batch_size_buf_;

    // Current addend index; wraps modulo MIDSTATE_ITERATIONS.
    uint64_t seq_{0};

    bool addends_initialized_{false};
};

// Engine implementation backed by OpenCL GPU devices.
class OpenclEngine final : public Engine
{
public:
    const char* name() override { return "opencl"; }

    // Prints a short summary of every available OpenCL GPU to stdout.
    void print_info() override {
        auto gpus = get_gpus();
        // size() is size_t: use %zu (passing size_t to %d is undefined
        // behavior and prints garbage on LP64 platforms).
        printf("OpenCL: %zu device(s)\n", gpus.size());
        size_t i = 0;
        for(auto& gpu : gpus) {
            std::cout << "Device #" << i << "(" << gpu.getInfo<CL_DEVICE_NAME>() << ")" << std::endl;
            printf("\tTotal global memory:        %.2fG\n", float(gpu.getInfo<CL_DEVICE_GLOBAL_MEM_SIZE>()) / 1024 / 1024 / 1024);
            // CL_DEVICE_MAX_WORK_GROUP_SIZE yields size_t; the clock and
            // compute-unit queries yield cl_uint — match the specifiers.
            printf("\tMaximum block size:         %zu\n", gpu.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>());
            printf("\tClock Rate:                 %uMHz\n", gpu.getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>());
            printf("\tNumber of multiprocessors:  %u\n", gpu.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>());
            ++i;
        }
    }

    std::string device_name(size_t i) override {
        auto gpu = get_gpu(i);
        return gpu.getInfo<CL_DEVICE_NAME>();
    }

    // Fills zero-valued parameters with device-appropriate defaults and
    // validates the resulting combination. Throws std::runtime_error on
    // invalid parameters.
    void fill_params(GeneratorParams& params) override {
        auto gpu = get_gpu(params.device);

        if(!params.block_size)
            params.block_size = std::min<size_t>(DEFAULT_BLOCK_SIZE, gpu.getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>());

        // Power-of-two check: exactly one bit set.
        if(std::bitset<sizeof(params.block_size) * 8>(params.block_size).count() != 1)
            throw std::runtime_error{"Block size must be a power of two"};

        if(!params.n_blocks)
            params.n_blocks = BLOCKS_PER_MP * gpu.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();

        // The address pipeline launches 2 * n_blocks blocks, so the
        // inversion batch may not exceed that.
        if(!params.inv_batch_size)
            params.inv_batch_size = std::min<size_t>(DEFAULT_INV_BATCH_SIZE, params.n_blocks * 2);

        if(params.inv_batch_size > params.n_blocks * 2)
            throw std::runtime_error{"Invert batch size is too large"};
    }

    std::unique_ptr<Generator> make_generator(const GeneratorParams& params) override {
        return std::make_unique<OpenclGenerator>(params);
    }

    std::unique_ptr<GeneratorCtx> make_generator_context() override {
        return std::make_unique<OpenclGeneratorCtx>();
    }
};

std::unique_ptr<OpenclEngine> static_engine = std::make_unique<OpenclEngine>();

} // anonymous namespace

// Returns the process-wide OpenCL engine singleton.
Engine& get_engine()
{
    return *static_engine;
}

} // namespace opencl